# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Schedule for bitserial dense operator."""
from __future__ import absolute_import as _abs
import tvm
from tvm import te
from tvm import autotvm
from tvm.topi.util import get_const_int, get_const_tuple
from .. import tag
from ..nn.bitserial_util import bitpack, binary_op_multiplier

@autotvm.register_topi_compute('bitserial_dense.x86')
def bitserial_dense(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
                    out_dtype='int16', unipolar=True):
    """Bitserial dense (fully-connected) compute for x86.

    The activations are always bit-packed with ``bitpack``. The weights are
    bit-packed too, unless they are passed in already packed (3-D).

    Parameters
    ----------
    cfg : autotvm config
        The autotvm configuration entity.
        The tiling/reorder/annotation search space is defined on it here.
    data : tvm.te.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.te.Tensor
        2-D with shape [out_dim, in_dim], or
        3-D with shape [out_dim, weight_bits, K]. In the 3-D case the weights
        are already packed the way
        ``bitpack(weight, weight_bits, pack_axis=1, bit_axis=1)`` would pack them.
    data_bits : int
        Number of bit planes the activations are split into.
    weight_bits : int
        Number of bit planes the weights are split into.
        Only used when ``weight`` is 2-D.
    pack_dtype : str
        Word type the bits are packed into (passed to ``bitpack``).
    out_dtype : str
        Data type of the accumulated output.
    unipolar : bool
        If True, each term is ``popcount(w & d) - popcount(~w & d)``,
        so a cleared weight bit contributes negatively.
        If False, only ``popcount(w & d)`` is accumulated.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1,
                                pack_type=pack_dtype)
    else:
        # Caller supplied pre-packed weights; used as-is.
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)

    ######## Search space
    x, y = cfg.axis(X), cfg.axis(Y)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)

    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
                       policy='candidate', candidate=[
                           [yo, xo, ko, yi, wb, db, ki, xi],
                           [yo, xo, yi, ko, wb, db, ki, xi]])

    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')

    ###### Compute rule
    # Innermost output-column tile size; the weight layout below is split by it.
    VX = cfg['tile_x'].size[-1]

    wvshape = (X//VX, WB, VX, K)
    oshape = (Y, X)

    k = te.reduce_axis((0, K), name='k')
    db = te.reduce_axis((0, DB), name='db')
    wb = te.reduce_axis((0, WB), name='wb')

    # Re-layout the packed weights so that VX consecutive output columns are
    # adjacent in the third dimension, matching the 'tile_x' split.
    weight_vec = te.compute(wvshape, lambda xo, wb, vx, k:
                            weight_packed[xo*VX+vx][wb][k], name='weight_vec')

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod

    # Build only the ComputeOp that is actually returned.
    # This keeps the reduce axes (wb, db, k) owned by a single op.
    # Each popcount term is shifted left by (db + wb) to weight it by its bit planes.
    if unipolar:
        matmul = te.compute(oshape, lambda i, j: te.sum(
            (tvm.tir.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] &
                              data_packed[i, db, k]) -
             tvm.tir.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] &
                              data_packed[i, db, k])
             ).astype(out_dtype)
            << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')
    else:
        matmul = te.compute(oshape, lambda i, j: te.sum(
            tvm.tir.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] &
                             data_packed[i, db, k]
                             ).astype(out_dtype)
            << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')

    # binary ops
    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))

    return matmul

@autotvm.register_topi_schedule('bitserial_dense.x86')
def schedule_bitserial_dense(cfg, outs):
    """Schedule for bitserial_dense.

    The schedule is registered under the same autotvm task name as the
    ``bitserial_dense.x86`` compute, so both halves of the template are linked.

    Parameters
    ----------
    cfg : autotvm config
        The autotvm configuration entity chosen for this workload.
    outs: Array of Tensor or Tensor
        The computation graph description of the bitserial dense operator,
        as an array of tensors (a single Tensor is also accepted).

    Returns
    -------
    s: Schedule
        The computation schedule for bitserial_dense.

    Raises
    ------
    RuntimeError
        If an op that is neither broadcast/elemwise nor a bitserial dense op
        is reached while traversing from the first output.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])

    def _schedule(cfg, s, data_vec, weight_vec, output):
        """Tile, reorder and annotate the dense `output` stage according to cfg."""
        # Parallelize the packing stages over their outermost axis.
        s[data_vec].parallel(s[data_vec].op.axis[0])
        s[weight_vec].parallel(s[weight_vec].op.axis[0])

        y, x = s[output].op.axis
        wb, db, k = s[output].op.reduce_axis

        yo, yi = cfg["tile_y"].apply(s, output, y)
        xo, xi = cfg["tile_x"].apply(s, output, x)
        ko, ki = cfg["tile_k"].apply(s, output, k)

        cfg["reorder_0"].apply(s, output, [yo, xo, ko, yi, wb, db, ki, xi])
        cfg["ann_reduce"].apply(s, output, [db, wb],
                                axis_lens=[get_const_int(db.dom.extent),
                                           get_const_int(wb.dom.extent)],
                                max_unroll=8,
                                cfg=cfg)
        cfg["ann_spatial"].apply(s, output, [yi, xi],
                                 axis_lens=[cfg['tile_y'].size[-1],
                                            cfg['tile_x'].size[-1]],
                                 max_unroll=8,
                                 cfg=cfg)
        s[output].vectorize(xi)
        s[output].parallel(yo)
        return s

    def traverse(op):
        """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.te.ComputeOp):
                    traverse(tensor.op)

        # A bare `op.tag == a or b` is always truthy.
        # Use an explicit membership test so unknown ops reach the error below.
        elif op.tag in ('bitserial_dense', 'bitserial_dense_unipolar'):
            output = op.output(0)
            # The compute body reads weight_vec first, then the packed data.
            weight_vec = op.input_tensors[0]
            data_vec = op.input_tensors[1]
            _schedule(cfg, s, data_vec, weight_vec, output)
        else:
            raise RuntimeError("Unsupported operator: %s" % op.tag)

    traverse(outs[0].op)
    return s
